{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary \n",
    "This notebook contains:\n",
    "- torch implementations of a few linear algebra techniques:\n",
    "    - forward- and back-solving \n",
    "    - LDLt decomposition\n",
    "    - QR decomposition via Householder reflections\n",
    "\n",
    "- initial implementations of secure linear regression and Jonathan Bloom's [DASH](https://github.com/jbloom22/DASH/) that leverage PySyft for secure computation.\n",
    "\n",
    "These implementations of linear regression and DASH are not yet strictly secure, in that a few final steps are performed on the local worker for now. That's because our implementations of LDLt decomposition, QR decomposition, etc. don't quite work for the PySyft `AdditiveSharingTensor` just yet. They should work in principle (they're compositions of operations that the SPDZ protocol supports), but there are still a few details to hammer out.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents \n",
    "[Ordinary least squares regression and LDLt decomposition](#OLSandLDLt)\n",
    "* [LDLt decomposition, forward/back-solving](#LDLt)\n",
    "* [Secure linear regression example](#OLS)\n",
    "\n",
    "[DASH](#dashqr)\n",
    "* [QR decomposition via Householder transforms](#qr)\n",
    "* [DASH example](#dash)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0710 23:13:43.013911 4542494144 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/Users/andrew/.virtualenvs/pysyft/lib/python3.7/site-packages/tf_encrypted-0.5.6-py3.7-macosx-10.14-x86_64.egg/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'\n",
      "W0710 23:13:43.023926 4542494144 deprecation_wrapper.py:119] From /Users/andrew/.virtualenvs/pysyft/lib/python3.7/site-packages/tf_encrypted-0.5.6-py3.7-macosx-10.14-x86_64.egg/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch as th\n",
    "import syft as sy\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up Sandbox...\n",
      "\t- Hooking PyTorch\n",
      "\t- Creating Virtual Workers:\n",
      "\t\t- bob\n",
      "\t\t- theo\n",
      "\t\t- jason\n",
      "\t\t- alice\n",
      "\t\t- andy\n",
      "\t\t- jon\n",
      "\tStoring hook and workers as global variables...\n",
      "\tLoading datasets from SciKit Learn...\n",
      "\t\t- Boston Housing Dataset\n",
      "\t\t- Diabetes Dataset\n",
      "\t\t- Breast Cancer Dataset\n",
      "\t- Digits Dataset\n",
      "\t\t- Iris Dataset\n",
      "\t\t- Wine Dataset\n",
      "\t\t- Linnerud Dataset\n",
      "\tDistributing Datasets Amongst Workers...\n",
      "\tCollecting workers into a VirtualGrid...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "sy.create_sandbox(globals())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <a id='OLSandLDLt'>Ordinary least squares regression and LDLt decomposition</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id='LDLt'>LDLt decomposition, forward/back-solving</a>\n",
    "\n",
    "These are torch implementations of basic linear algebra routines we'll use to perform regression (and also in parts of the next section). \n",
    "- Forward/back-solving allows us to solve triangular linear systems efficiently and stably.\n",
    "- LDLt decomposition lets us write a symmetric matrix as a product LDL^t, where L is lower-triangular and D is diagonal (^t denotes transpose). It plays a role similar to Cholesky decomposition (which is normally available as a method of a torch tensor), but doesn't require computing square roots, which makes LDLt a better fit for the secure setting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _eye(n):\n",
    "    \"\"\"th.eye doesn't seem to work after hooking torch, so just adding\n",
    "    a workaround for now.\n",
    "    \"\"\"\n",
    "    return th.FloatTensor(np.eye(n))\n",
    "\n",
    "\n",
    "def ldlt_decomposition(x):\n",
    "    \"\"\"Decompose the square, symmetric, full-rank matrix X as X = LDL^t, where \n",
    "        - L is lower triangular\n",
    "        - D is diagonal.\n",
    "    \"\"\"\n",
    "    n, _ = x.shape\n",
    "    l, diag = _eye(n), th.zeros(n).float()\n",
    "\n",
    "    for j in range(n):\n",
    "        diag[j] = x[j, j] - (th.sum((l[j, :j] ** 2) * diag[:j]))\n",
    "        for i in range(j + 1, n):\n",
    "            l[i, j] = (x[i, j] - th.sum(diag[:j] * l[i, :j] * l[j, :j])) / diag[j]\n",
    "\n",
    "    return l, th.diag(diag), l.transpose(0, 1)\n",
    "\n",
    "\n",
    "def back_solve(u, y):\n",
    "    \"\"\"Solve Ux = y for U a square, upper triangular matrix of full rank\"\"\"\n",
    "    n = u.shape[0]\n",
    "    x = th.zeros(n)\n",
    "    for i in range(n - 1, -1, -1):\n",
    "        x[i] = (y[i] - th.sum(u[i, i+1:] * x[i+1:])) / u[i, i]\n",
    "\n",
    "    return x.reshape(-1, 1)\n",
    "\n",
    "\n",
    "def forward_solve(l, y):\n",
    "    \"\"\"Solve Lx = y for L a square, lower triangular matrix of full rank.\"\"\"\n",
    "    n = l.shape[0]\n",
    "    x = th.zeros(n)\n",
    "    for i in range(0, n):\n",
    "        x[i] = (y[i] - th.sum(l[i, :i] * x[:i])) / l[i, i]\n",
    "\n",
    "    return x.reshape(-1, 1)\n",
    "\n",
    "\n",
    "def invert_triangular(t, upper=True):\n",
    "    \"\"\"\n",
    "    Invert by repeated forward/back-solving.\n",
    "    TODO: -Could be made more efficient with vectorized implementation of forward/backsolve\n",
    "          -detection and validation around triangularity/squareness\n",
    "    \"\"\"\n",
    "    solve = back_solve if upper else forward_solve\n",
    "    t_inv = th.zeros_like(t)\n",
    "    n = t.shape[0]\n",
    "    for i in range(n):\n",
    "        e = th.zeros(n, 1)\n",
    "        e[i] = 1.\n",
    "        t_inv[:, [i]] = solve(t, e)\n",
    "    return t_inv\n",
    "\n",
    "\n",
    "def solve_symmetric(a, y):\n",
    "    \"\"\"Solve the linear system Ax = y where A is a symmetric matrix of full rank.\"\"\"\n",
    "    l, d, lt = ldlt_decomposition(a)\n",
    "    \n",
    "    # TODO: more efficient to just extract diagonal of d as 1D vector and scale?\n",
    "    x_ = forward_solve(l.mm(d), y)\n",
    "    return back_solve(lt, x_)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PASSED for tensor([[1., 2., 3.],\n",
      "        [2., 1., 2.],\n",
      "        [3., 2., 1.]])\n",
      "PASSED for tensor([[1., 2., 3.],\n",
      "        [2., 1., 2.],\n",
      "        [3., 2., 1.]]), tensor([[1.],\n",
      "        [2.],\n",
      "        [3.]])\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Basic tests for LDLt decomposition.\n",
    "\"\"\"\n",
    "\n",
    "def _assert_small(x, failure_msg=None, threshold=1E-5):\n",
    "    norm = x.norm()\n",
    "    assert norm < threshold, failure_msg\n",
    "\n",
    "\n",
    "def test_ldlt_case(a):\n",
    "    l, d, lt = ldlt_decomposition(a)\n",
    "    _assert_small(l - lt.transpose(0, 1))\n",
    "    _assert_small(l.mm(d).mm(lt) - a, 'Decomposition is inaccurate.')\n",
    "    _assert_small(l - th.tril(l), 'L is not lower triangular.')\n",
    "    _assert_small(th.triu(th.tril(d)) - d, 'D is not diagonal.')\n",
    "    print(f'PASSED for {a}')\n",
    "    \n",
    "\n",
    "def test_solve_symmetric_case(a, x):\n",
    "    y = a.mm(x)\n",
    "    _assert_small(solve_symmetric(a, y) - x)\n",
    "    print(f'PASSED for {a}, {x}')\n",
    "\n",
    "    \n",
    "a = th.tensor([[1, 2, 3],\n",
    "               [2, 1, 2],\n",
    "               [3, 2, 1]]).float()\n",
    "\n",
    "x = th.tensor([1, 2, 3]).float().reshape(-1, 1)\n",
    "\n",
    "test_ldlt_case(a)\n",
    "test_solve_symmetric_case(a, x)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id='OLS'>Secure linear regression example</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Problem\n",
    "We're solving \n",
    "$$ \\min_\\beta \\|X \\beta - y\\|_2 $$\n",
    "in the situation where the data $(X, y)$ is horizontally partitioned (each worker $w$ owns chunks $X_w, y_w$ of the rows of $X$ and $y$).\n",
    "\n",
    "#### Goals\n",
    "We want to do this \n",
    "* securely \n",
    "* without network overhead or MPC-related costs that scale with the number of rows of $X$. \n",
    "\n",
    "#### Plan\n",
    "\n",
    "1. (**local plaintext compression**): each worker locally computes $X_w^t X_w$ and $X_w^t y_w$ in plaintext. This is the only step whose cost depends on the number of rows of $X$.\n",
    "2. (**secure summing**): securely compute the sums $$\\begin{align}X^t X &= \\sum_w X^t_w X_w \\\\ X^t y &= \\sum_w X^t_w y_w \\end{align}$$ as `AdditiveSharingTensor`s. Some worker or other party (here the local worker) will hold pointers to those two tensors.\n",
    "3. (**secure solve**): We can then solve $X^tX\\beta = X^ty$ for $\\beta$ by a sequence of operations on those pointers (specifically, we apply `solve_symmetric` defined above).\n",
    "\n",
    "#### Example data: \n",
    "The correct $\\beta$ is $[1, 2, -1]$"
   ]
  },
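  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the plan concrete, here's a minimal plaintext NumPy sketch of steps 1-3 (illustrative only: there's no sharing or MPC here, and the row chunks stand in for the workers' partitions):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)\n",
    "X_all = 10 * np.random.randn(300, 3)\n",
    "y_all = X_all.dot(np.array([1., 2., -1.]))\n",
    "\n",
    "# Step 1 (local compression): each chunk of rows is reduced to a k x k\n",
    "# and a k x 1 summand, so nothing downstream scales with the row count.\n",
    "chunks = np.array_split(np.arange(300), 3)\n",
    "XtX = sum(X_all[idx].T.dot(X_all[idx]) for idx in chunks)\n",
    "Xty = sum(X_all[idx].T.dot(y_all[idx]) for idx in chunks)\n",
    "\n",
    "# Steps 2-3: sum the summands and solve the k x k normal equations\n",
    "# (the secure version does this on AdditiveSharingTensors instead).\n",
    "beta = np.linalg.solve(XtX, Xty)\n",
    "print(beta)  # ~ [ 1.  2. -1.]"
   ]
  },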
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = th.tensor(10 * np.random.randn(30000, 3))\n",
    "y = (X[:, 0] + 2 * X[:, 1] - X[:, 2]).reshape(-1, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the data into chunks and send a chunk to each worker, storing pointers to chunks in two `MultiPointerTensor`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "workers = [alice, bob, theo]\n",
    "crypto_provider = jon\n",
    "chunk_size = int(X.shape[0] / len(workers))\n",
    "\n",
    "\n",
    "def _get_chunk_pointers(data, chunk_size, workers):\n",
    "    return [\n",
    "        data[(i * chunk_size):((i+1)*chunk_size), :].send(worker)\n",
    "        for i, worker in enumerate(workers)\n",
    "    ] \n",
    "\n",
    "\n",
    "X_ptrs = sy.MultiPointerTensor(\n",
    "    children=_get_chunk_pointers(X, chunk_size, workers))\n",
    "y_ptrs = sy.MultiPointerTensor(\n",
    "    children=_get_chunk_pointers(y, chunk_size, workers))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### local compression\n",
    "This is the only step that depends on the number of rows of $X, y$, and it's performed locally on each worker in plaintext. The result is two `MultiPointerTensor`s with pointers to each worker's summand of $X^tX$ (or $X^ty$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "Xt_ptrs = X_ptrs.transpose(0, 1)\n",
    "\n",
    "XtX_summand_ptrs = Xt_ptrs.mm(X_ptrs)\n",
    "Xty_summand_ptrs = Xt_ptrs.mm(y_ptrs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### secure sum\n",
    "We add those summands up in two steps:\n",
    "- share each summand among all other workers\n",
    "- move the resulting pointers to one place (here just the local worker) and add 'em up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _generate_shared_summand_pointers(\n",
    "        summand_ptrs, \n",
    "        workers, \n",
    "        crypto_provider):\n",
    "\n",
    "    for worker_id, summand_pointer in summand_ptrs.child.items():\n",
    "        shared_summand_pointer = summand_pointer.fix_precision().share(\n",
    "            *workers, crypto_provider=crypto_provider)\n",
    "        yield shared_summand_pointer.get()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "XtX_shared = sum(\n",
    "    _generate_shared_summand_pointers(\n",
    "        XtX_summand_ptrs, workers, crypto_provider))\n",
    "\n",
    "Xty_shared = sum(_generate_shared_summand_pointers(\n",
    "    Xty_summand_ptrs, workers, crypto_provider))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### secure solve\n",
    "The coefficient $\\beta$ is the solution to\n",
    "$$X^t X \\beta = X^t y$$\n",
    "\n",
    "We solve for $\\beta$ using `solve_symmetric`. Critically, this is a composition of linear operations that should be supported by `AdditiveSharingTensor`. Unlike the classic Cholesky decomposition, the $LDL^t$ decomposition in step 1 does not involve taking square roots, which would be challenging to compute securely.\n",
    "\n",
    "\n",
    "**TODO**: there's still some additional work required to get `solve_symmetric` working for `AdditiveSharingTensor`, so we're performing the final linear solve publicly for now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = solve_symmetric(XtX_shared.get().float_precision(), Xty_shared.get().float_precision())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0000],\n",
       "        [ 2.0000],\n",
       "        [-1.0000]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "beta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <a id='dashqr'>DASH and QR decomposition</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id='qr'>QR decomposition</a>\n",
    "\n",
    "An $m \\times n$ real matrix $A$ with $m \\geq n$ can be written as $$A = QR$$ with $Q$ orthogonal and $R$ upper triangular. This is helpful in solving systems of equations, among other things. It is also central to the compression idea of [DASH](https://arxiv.org/pdf/1901.09531.pdf)."
   ]
  },
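  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The building block is the Householder reflection, which maps a vector onto a multiple of $e_1$. Here's a small self-contained NumPy sketch of a single reflection, using the same sign convention as the implementation below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# One Householder reflection: map x onto a multiple of e_1,\n",
    "# zeroing every entry below the first.\n",
    "x = np.array([3., 4., 0.])\n",
    "u = x.copy()\n",
    "u[0] += np.linalg.norm(x)  # same convention as _householder_qr_step\n",
    "u /= np.linalg.norm(u)\n",
    "H = np.eye(3) - 2.0 * np.outer(u, u)\n",
    "print(H.dot(x))  # ~ [-5.  0.  0.]"
   ]
  },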
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Full QR decomposition via Householder transforms, \n",
    "following Numerical Linear Algebra (Trefethen and Bau).\n",
    "\"\"\"\n",
    "\n",
    "def _apply_householder_transform(a, v):\n",
    "    return a - 2 * v.mm(v.transpose(0, 1).mm(a))\n",
    "\n",
    "\n",
    "def _build_householder_matrix(v):\n",
    "    n = v.shape[0]\n",
    "    u = v / v.norm()\n",
    "    return _eye(n) - 2 * u.mm(u.transpose(0, 1))\n",
    "\n",
    "\n",
    "def _householder_qr_step(a):\n",
    "    x = a[:, 0].reshape(-1, 1)\n",
    "    alpha = x.norm()\n",
    "    u = x.copy()\n",
    "\n",
    "    # note: can get better stability by multiplying by sign(u[0, 0])\n",
    "    # (where sign(0) = 1); is this supported in the secure context?\n",
    "    u[0, 0] += u.norm()\n",
    "    \n",
    "    # is there a simple way of getting around computing the norm twice?\n",
    "    u /= u.norm()\n",
    "    a = _apply_householder_transform(a, u)\n",
    "\n",
    "    return a, u\n",
    "\n",
    "\n",
    "def _recover_q(householder_vectors):\n",
    "    \"\"\"\n",
    "    Build the matrix Q from the Householder transforms.\n",
    "    \"\"\"\n",
    "    def _apply_transforms(x):\n",
    "        \"\"\"Trefethen and Bau, Algorithm 10.3\"\"\"\n",
    "        for k in range(n-1, -1, -1):\n",
    "            x[k:, :] = _apply_householder_transform(\n",
    "                x[k:, :], \n",
    "                householder_vectors[k])\n",
    "        return x\n",
    "\n",
    "    m = householder_vectors[0].shape[0]\n",
    "    n = len(householder_vectors)\n",
    "    q = th.zeros(m, m)\n",
    "    \n",
    "    # Determine q by evaluating it on a basis\n",
    "    for i in range(m):\n",
    "        e = th.zeros(m, 1)\n",
    "        e[i] = 1.\n",
    "        q[:, [i]] = _apply_transforms(e)\n",
    "    \n",
    "    return q\n",
    "\n",
    "\n",
    "def qr(a, return_q=True):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        a: shape (m, n), m >= n\n",
    "        return_q: bool, whether to reconstruct q \n",
    "    Returns:\n",
    "        orthogonal q of shape (m, m) (None if return_q is False)\n",
    "        upper-triangular r of shape (m, n)\n",
    "    \"\"\"\n",
    "    m, n = a.shape\n",
    "    assert m >= n, \\\n",
    "        f\"Passed a of shape {a.shape}, must have a.shape[0] >= a.shape[1]\"\n",
    "\n",
    "    r = a.copy()\n",
    "    householder_unit_normal_vectors = []\n",
    "\n",
    "    for k in range(n):\n",
    "        r[k:, k:], u = _householder_qr_step(r[k:, k:])\n",
    "        householder_unit_normal_vectors.append(u)\n",
    "    if return_q:\n",
    "        q = _recover_q(householder_unit_normal_vectors)\n",
    "    else:\n",
    "        q = None\n",
    "    return q, r\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PASSED for \n",
      "tensor([[1., 0., 1.],\n",
      "        [1., 1., 0.],\n",
      "        [0., 1., 1.]])\n",
      "\n",
      "PASSED for \n",
      "tensor([[1., 0., 1.],\n",
      "        [1., 1., 0.],\n",
      "        [0., 1., 1.],\n",
      "        [1., 1., 1.]])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Basic tests for QR decomposition\n",
    "\"\"\"\n",
    "\n",
    "def _test_qr_case(a): \n",
    "    \n",
    "    q, r = qr(a)\n",
    "    \n",
    "    # actually have QR = A\n",
    "    _assert_small(q.mm(r) - a, \"QR = A failed\")\n",
    "\n",
    "    # Q is orthogonal\n",
    "    m, _ = a.shape\n",
    "    _assert_small(\n",
    "        q.mm(q.transpose(0, 1)) - _eye(m),\n",
    "        \"QQ^t = I failed\"\n",
    "    )\n",
    "    \n",
    "    # R is upper triangular\n",
    "    lower_triangular_entries = th.tensor([\n",
    "        r[i, j].item() for i in range(r.shape[0]) \n",
    "             for j in range(i)])\n",
    "\n",
    "    _assert_small(\n",
    "        lower_triangular_entries,\n",
    "        \"R is not upper triangular\"\n",
    "    )\n",
    "\n",
    "    print(f\"PASSED for \\n{a}\\n\")\n",
    "\n",
    "\n",
    "def test_qr():\n",
    "    _test_qr_case(\n",
    "        th.tensor([[1, 0, 1],\n",
    "                   [1, 1, 0],\n",
    "                   [0, 1, 1]]).float()\n",
    "    )\n",
    "\n",
    "    _test_qr_case(\n",
    "        th.tensor([[1, 0, 1],\n",
    "                   [1, 1, 0],\n",
    "                   [0, 1, 1],\n",
    "                   [1, 1, 1],]).float()\n",
    "    )\n",
    "    \n",
    "test_qr()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id='dash'>DASH implementation</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We follow https://github.com/jbloom22/DASH/. \n",
    "\n",
    "The overall structure is roughly analogous to the linear regression example above.\n",
    "\n",
    "- There's a local compression step that's performed separately on each worker in plaintext.\n",
    "- We leverage PySyft's SMPC features to perform secure summation.\n",
    "- For now, the last few steps are performed by a single player (the local worker). \n",
    "    - Again, this could be performed securely, but there are still a few hitches with getting our torch implementation of QR decomposition to work for an `AdditiveSharingTensor`."
   ]
  },
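  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The covariate residualization at the heart of the compression can be checked in plaintext: by the Frisch-Waugh-Lovell theorem, the per-feature coefficient computed from the compressed statistics (using only $R$ from the QR of $C$) matches a naive per-feature regression of $y$ on $[x_j, C]$. A minimal NumPy sketch (illustrative only, no sharing):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)\n",
    "n, m, k = 200, 3, 2\n",
    "X = np.random.randn(n, m)  # transient features, tested one at a time\n",
    "C = np.random.randn(n, k)  # covariates, included in every model\n",
    "y = X.dot(np.array([1., 2., -1.])) + C.dot(np.array([0.5, -0.5]))\n",
    "\n",
    "# Compressed statistics (the summands each worker would compute locally)\n",
    "Xy, XX = X.T.dot(y), (X * X).sum(axis=0)\n",
    "Cty, CtX = C.T.dot(y), C.T.dot(X)\n",
    "\n",
    "# Residualize against the covariates using only R from the QR of C:\n",
    "# since C = QR, we have Q^t y = R^{-t} C^t y (and likewise for Q^t X).\n",
    "R = np.linalg.qr(C, mode='r')\n",
    "Qty = np.linalg.solve(R.T, Cty)\n",
    "QtX = np.linalg.solve(R.T, CtX)\n",
    "beta = (Xy - QtX.T.dot(Qty)) / (XX - (QtX * QtX).sum(axis=0))\n",
    "\n",
    "# Agrees with per-feature OLS of y on [x_j, C]\n",
    "for j in range(m):\n",
    "    coef = np.linalg.lstsq(np.column_stack([X[:, j], C]), y, rcond=None)[0]\n",
    "    assert abs(coef[0] - beta[j]) < 1e-6\n",
    "print('per-feature coefficients match')"
   ]
  },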
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _generate_worker_data_pointers(\n",
    "        n, m, k, worker,\n",
    "        beta_correct, gamma_correct, epsilon=0.01\n",
    "):\n",
    "    \"\"\"\n",
    "    Return pointers to worker-level data.\n",
    "    Args:\n",
    "        n: number of rows\n",
    "        m: number of transient features\n",
    "        k: number of covariates\n",
    "        beta_correct: coefficients for transient features (tensor of shape (m, 1))\n",
    "        gamma_correct: coefficients for covariates (tensor of shape (k, 1))\n",
    "        epsilon: scale of noise added to response\n",
    "    Return:\n",
    "        y, X, C: pointers to response, transients, and covariates\n",
    "    \"\"\"\n",
    "    X = th.randn(n, m).send(worker)\n",
    "    C = th.randn(n, k).send(worker)\n",
    "    \n",
    "    y = (X.mm(beta_correct.copy().send(worker)).reshape(-1, 1) + \n",
    "         C.mm(gamma_correct.copy().send(worker)).reshape(-1, 1))\n",
    "\n",
    "    y += (epsilon * th.randn(n, 1)).send(worker)\n",
    "\n",
    "    return y, X, C\n",
    "\n",
    "\n",
    "def _dot(x):\n",
    "    return (x * x).sum(dim=0).reshape(-1, 1)\n",
    "\n",
    "\n",
    "def _secure_sum(worker_level_pointers, workers, crypto_provider):\n",
    "    \"\"\"\n",
    "    Securely add up an iterable of pointers to (same-shaped) tensors.\n",
    "    Args:\n",
    "        worker_level_pointers: iterable of pointer tensors\n",
    "        workers: list of workers\n",
    "        crypto_provider: worker\n",
    "    Returns:\n",
    "        AdditiveSharingTensor shared among workers\n",
    "    \"\"\"\n",
    "    return sum([\n",
    "        p.fix_precision(precision_fractional=10).share(*workers, crypto_provider=crypto_provider).get()\n",
    "        for p in worker_level_pointers\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dash_example_secure(\n",
    "        workers, crypto_provider,\n",
    "        n_samples_by_worker, m, k,\n",
    "        beta_correct, gamma_correct,        \n",
    "        epsilon=0.01\n",
    "):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        workers: dict mapping worker ids to workers\n",
    "        crypto_provider: worker\n",
    "        n_samples_by_worker: dict mapping worker ids to ints (number of rows of data)\n",
    "        m: number of transients\n",
    "        k: number of covariates\n",
    "        beta_correct: coefficient for transient features\n",
    "        gamma_correct: coefficient for covariates\n",
    "        epsilon: scale of noise added to response\n",
    "    Returns:\n",
    "        beta, sigma, tstat, pval: coefficient of transients and accompanying statistics\n",
    "    \"\"\"\n",
    "    # Generate each worker's data\n",
    "    worker_data_pointers = {\n",
    "        p: _generate_worker_data_pointers(\n",
    "            n, m, k, workers[p],\n",
    "            beta_correct, gamma_correct,\n",
    "            epsilon=epsilon)\n",
    "        for p, n in n_samples_by_worker.items()\n",
    "    }\n",
    "\n",
    "    # to be populated with pointers to results of local, worker-level computations\n",
    "    Ctys, CtXs, yys, Xys, XXs, Rs = {}, {}, {}, {}, {}, {}\n",
    "\n",
    "    def _sum(pointers):\n",
    "        return _secure_sum(pointers, list(workers.values()), crypto_provider)\n",
    "\n",
    "    # worker-level compression step\n",
    "    for p, (y, X, C) in worker_data_pointers.items():\n",
    "        \n",
    "        # perform worker-level compression step\n",
    "        yys[p] = y.norm()\n",
    "        Xys[p] = X.transpose(0, 1).mm(y)\n",
    "        XXs[p] = _dot(X)\n",
    "        \n",
    "        Ctys[p] = C.transpose(0, 1).mm(y)\n",
    "        CtXs[p] = C.transpose(0, 1).mm(X)\n",
    "        _, R_full = qr(C, return_q=False)\n",
    "        Rs[p] = R_full[:k, :]\n",
    "    \n",
    "    # Perform secure sum \n",
    "    # - We return the result to the local worker and compute there for the rest\n",
    "    #   of the way, but this should be possible via SMPC (on pointers to AdditiveSharingTensors)\n",
    "    # - still a few minor-looking issues with implementing invert_triangular/qr for\n",
    "    #   AdditiveSharingTensor\n",
    "    yy = _sum(yys.values()).get().float_precision()\n",
    "    Xy = _sum(Xys.values()).get().float_precision()\n",
    "    XX = _sum(XXs.values()).get().float_precision()\n",
    "    \n",
    "    Cty = _sum(Ctys.values()).get().float_precision()\n",
    "    CtX = _sum(CtXs.values()).get().float_precision()\n",
    "\n",
    "    # Rest is done publicly on the local worker for now\n",
    "    _, R_public = qr(\n",
    "        th.cat([R.get() for R in Rs.values()], dim=0),\n",
    "        return_q=False)\n",
    "\n",
    "    invR_public = invert_triangular(R_public[:k, :])\n",
    "\n",
    "    Qty = invR_public.transpose(0, 1).mm(Cty)\n",
    "    QtX = invR_public.transpose(0, 1).mm(CtX)\n",
    "\n",
    "    QtXQty = QtX.transpose(0, 1).mm(Qty)\n",
    "    QtyQty = _dot(Qty)\n",
    "    QtXQtX = _dot(QtX)\n",
    "\n",
    "    yyq = yy - QtyQty\n",
    "    Xyq = Xy - QtXQty\n",
    "    XXq = XX - QtXQtX\n",
    "\n",
    "    d = sum(n_samples_by_worker.values()) - k - 1\n",
    "    beta = Xyq / XXq\n",
    "    sigma = ((yyq / XXq - (beta ** 2)) / d).abs() ** 0.5\n",
    "    tstat = beta / sigma\n",
    "    pval = 2 * stats.t.cdf(-abs(tstat), d)\n",
    "\n",
    "    return beta, sigma, tstat, pval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[1.0198],\n",
       "         [0.9758],\n",
       "         [0.9643],\n",
       "         [1.0004],\n",
       "         [1.0150],\n",
       "         [0.9944],\n",
       "         [0.9961],\n",
       "         [1.0318],\n",
       "         [1.0002],\n",
       "         [0.9830],\n",
       "         [0.9790],\n",
       "         [0.9926],\n",
       "         [1.0008],\n",
       "         [0.9697],\n",
       "         [0.9954],\n",
       "         [0.9995],\n",
       "         [0.9960],\n",
       "         [0.9767],\n",
       "         [1.0059],\n",
       "         [0.9838],\n",
       "         [0.9911],\n",
       "         [1.0179],\n",
       "         [1.0080],\n",
       "         [0.9829],\n",
       "         [0.9937],\n",
       "         [0.9819],\n",
       "         [1.0188],\n",
       "         [0.9811],\n",
       "         [0.9971],\n",
       "         [0.9866],\n",
       "         [1.0117],\n",
       "         [0.9953],\n",
       "         [0.9966],\n",
       "         [0.9952],\n",
       "         [0.9957],\n",
       "         [0.9860],\n",
       "         [1.0206],\n",
       "         [0.9928],\n",
       "         [0.9925],\n",
       "         [1.0149],\n",
       "         [0.9587],\n",
       "         [0.9851],\n",
       "         [1.0102],\n",
       "         [1.0127],\n",
       "         [1.0143],\n",
       "         [1.0050],\n",
       "         [0.9926],\n",
       "         [0.9646],\n",
       "         [0.9966],\n",
       "         [0.9906],\n",
       "         [1.0212],\n",
       "         [0.9948],\n",
       "         [1.0253],\n",
       "         [0.9936],\n",
       "         [0.9834],\n",
       "         [0.9770],\n",
       "         [0.9885],\n",
       "         [0.9890],\n",
       "         [0.9954],\n",
       "         [0.9900],\n",
       "         [0.9795],\n",
       "         [0.9657],\n",
       "         [0.9836],\n",
       "         [1.0042],\n",
       "         [0.9957],\n",
       "         [0.9929],\n",
       "         [1.0127],\n",
       "         [0.9869],\n",
       "         [0.9969],\n",
       "         [1.0172],\n",
       "         [1.0030],\n",
       "         [0.9844],\n",
       "         [1.0121],\n",
       "         [1.0071],\n",
       "         [0.9954],\n",
       "         [0.9936],\n",
       "         [0.9954],\n",
       "         [1.0070],\n",
       "         [0.9928],\n",
       "         [0.9900],\n",
       "         [0.9970],\n",
       "         [0.9992],\n",
       "         [0.9851],\n",
       "         [0.9942],\n",
       "         [0.9710],\n",
       "         [0.9799],\n",
       "         [0.9675],\n",
       "         [1.0246],\n",
       "         [1.0085],\n",
       "         [0.9906],\n",
       "         [0.9984],\n",
       "         [1.0182],\n",
       "         [0.9805],\n",
       "         [0.9905],\n",
       "         [1.0034],\n",
       "         [0.9965],\n",
       "         [0.9983],\n",
       "         [0.9973],\n",
       "         [0.9872],\n",
       "         [0.9937]]), tensor([[0.0032],\n",
       "         [0.0032],\n",
       "         [0.0031],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0031],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0031],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0031],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0031],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032],\n",
       "         [0.0032]]), tensor([[319.5643],\n",
       "         [309.3729],\n",
       "         [306.5470],\n",
       "         [315.1175],\n",
       "         [318.3896],\n",
       "         [313.9088],\n",
       "         [314.0636],\n",
       "         [322.6956],\n",
       "         [315.1400],\n",
       "         [311.2341],\n",
       "         [310.3611],\n",
       "         [313.5351],\n",
       "         [315.4006],\n",
       "         [307.5811],\n",
       "         [313.6877],\n",
       "         [314.7977],\n",
       "         [314.0446],\n",
       "         [309.6685],\n",
       "         [316.4051],\n",
       "         [311.1368],\n",
       "         [313.2263],\n",
       "         [319.4759],\n",
       "         [316.2847],\n",
       "         [310.5904],\n",
       "         [313.5548],\n",
       "         [311.4909],\n",
       "         [319.7083],\n",
       "         [310.5916],\n",
       "         [314.3991],\n",
       "         [311.7694],\n",
       "         [317.7445],\n",
       "         [314.2973],\n",
       "         [313.8512],\n",
       "         [313.7931],\n",
       "         [313.8863],\n",
       "         [311.7593],\n",
       "         [320.1660],\n",
       "         [313.1064],\n",
       "         [312.8649],\n",
       "         [318.5549],\n",
       "         [305.3120],\n",
       "         [311.6055],\n",
       "         [317.7577],\n",
       "         [318.0811],\n",
       "         [318.6992],\n",
       "         [316.0238],\n",
       "         [313.5295],\n",
       "         [306.4926],\n",
       "         [314.0880],\n",
       "         [312.5821],\n",
       "         [319.9753],\n",
       "         [313.6790],\n",
       "         [321.0503],\n",
       "         [313.4674],\n",
       "         [310.9974],\n",
       "         [310.1643],\n",
       "         [312.6641],\n",
       "         [312.5327],\n",
       "         [313.7359],\n",
       "         [312.9331],\n",
       "         [310.3143],\n",
       "         [307.3026],\n",
       "         [311.3515],\n",
       "         [316.4857],\n",
       "         [314.6939],\n",
       "         [313.4534],\n",
       "         [318.1405],\n",
       "         [312.0404],\n",
       "         [314.7274],\n",
       "         [318.8090],\n",
       "         [315.5815],\n",
       "         [311.3936],\n",
       "         [317.9680],\n",
       "         [317.0735],\n",
       "         [314.1143],\n",
       "         [313.6470],\n",
       "         [313.5154],\n",
       "         [316.9404],\n",
       "         [313.7480],\n",
       "         [312.7427],\n",
       "         [314.7064],\n",
       "         [314.5428],\n",
       "         [311.3454],\n",
       "         [313.8516],\n",
       "         [308.4349],\n",
       "         [310.3224],\n",
       "         [307.0361],\n",
       "         [320.6432],\n",
       "         [317.1022],\n",
       "         [312.7644],\n",
       "         [314.5857],\n",
       "         [319.2417],\n",
       "         [310.1894],\n",
       "         [312.5483],\n",
       "         [315.9466],\n",
       "         [314.3875],\n",
       "         [314.4272],\n",
       "         [314.6446],\n",
       "         [312.2668],\n",
       "         [313.8906]]), array([[0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.],\n",
       "        [0.]]))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "players = {\n",
    "    worker.id: worker \n",
    "    for worker in [alice, bob, theo]\n",
    "}\n",
    "\n",
     "# number of samples held by each data-owning worker\n",
    "n_samples_by_player = {\n",
    "    alice.id: 100000,\n",
    "    bob.id: 200000,\n",
    "    theo.id: 100000\n",
    "}\n",
    "\n",
     "crypto_provider = jon  # worker that generates the shared randomness for SPDZ\n",
    "\n",
     "m = 100  # number of variants tested\n",
     "k = 3    # number of covariates\n",
     "d = sum(n_samples_by_player.values()) - k - 1  # residual degrees of freedom\n",
    "\n",
    "\n",
     "beta_correct = th.ones(m, 1)   # ground-truth variant effects\n",
     "gamma_correct = th.ones(k, 1)  # ground-truth covariate effects\n",
    "\n",
    "dash_example_secure(\n",
    "    players, crypto_provider, \n",
    "    n_samples_by_player, m, k, \n",
    "    beta_correct, gamma_correct)"
   ]
  },
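  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a plaintext sanity check (not part of the secure protocol), the same kind of coefficient recovery can be done with ordinary least squares via a QR solve, which is the technique the secure QR implementation above mimics. The sketch below is self-contained and hypothetical: it uses `torch.linalg.qr` and `torch.linalg.solve_triangular` (which assume a recent PyTorch), and the problem sizes, seed, and variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "\n",
    "# Build a synthetic regression problem with known coefficients.\n",
    "th.manual_seed(0)\n",
    "n, m = 1000, 5\n",
    "X = th.randn(n, m)\n",
    "beta_true = th.ones(m, 1)\n",
    "y = X @ beta_true\n",
    "\n",
    "# QR-based least squares: X = QR with R upper-triangular,\n",
    "# so beta solves R beta = Q^T y by back-substitution.\n",
    "Q, R = th.linalg.qr(X)\n",
    "beta_hat = th.linalg.solve_triangular(R, Q.T @ y, upper=True)\n",
    "assert th.allclose(beta_hat, beta_true, atol=1e-4)"
   ]
  },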
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
